#include "yolov8.h"
#include "decode_yolov8.h"
#include <opencv2/imgproc.hpp>
#include <cuda_runtime.h>

namespace yolo
{

    // Construct a YOLOv8 detector. The parameters are forwarded unchanged to the
    // yolo::YOLO base; no device memory is touched until init() runs.
    YOLOV8::YOLOV8(const YoloParams& param)
        : yolo::YOLO(param)
    {
    }

    // Release the transpose scratch buffer that init() allocates with cudaMalloc.
    // NOTE(review): init() also allocates m_output_src_device, which is not freed
    // here; presumably the yolo::YOLO base destructor releases it — confirm.
    YOLOV8::~YOLOV8()
    {
        auto* const transpose_buffer = m_output_src_transpose_device;
        checkRuntime(cudaFree(transpose_buffer));
    }

    /**
     * Prepare the detector for inference.
     *
     * Loads the model, creates the execution context, optionally sets the dynamic
     * input shape, inspects the output binding, allocates the output buffers, and
     * precomputes the inverse letterbox affine matrix (network -> source coords).
     *
     * @param mdl_file  model path handed to load_model().
     * @return true on success. Returns false if no engine/context is available,
     *         the input shape cannot be applied, or the output binding has fewer
     *         than 3 dims or a batch dimension smaller than m_param.batch_size.
     */
    bool YOLOV8::init(const std::string& mdl_file)
    {
        // 1. init engine & context
        this->load_model(mdl_file);
        // load_model()'s outcome is not inspected here, so guard the engine before
        // dereferencing it.
        if (this->m_engine == nullptr)
        {
            return false;
        }
        this->m_context = std::unique_ptr<nvinfer1::IExecutionContext>(this->m_engine->createExecutionContext());
        if (this->m_context == nullptr)
        {
            return false;
        }
        // binding dim: only set for dynamic-batch engines; some models only support a
        // static batch (e.g. yolox).
        if (m_param.dynamic_batch)
        {
            if (!this->m_context->setBindingDimensions(0, nvinfer1::Dims4(m_param.batch_size, 3, m_param.dst_h, m_param.dst_w)))
            {
                return false;
            }
        }

        // 2. get output's dim. postprocess() treats it as
        // [batch, 4 + num_class, m_total_objects] (presumably the YOLOv8 head layout).
        m_output_dims = this->m_context->getBindingDimensions(1);
        // Runtime check (not just assert) so release builds cannot run with a batch
        // larger than the engine's output, or read d[2] from a rank-2 binding.
        if (m_output_dims.nbDims < 3 || m_param.batch_size > m_output_dims.d[0])
        {
            return false;
        }
        m_total_objects = m_output_dims.d[2];
        // Elements per image: product of all non-batch dims, skipping zero-sized ones.
        m_output_area = 1;
        for (int i = 1; i < m_output_dims.nbDims; i++)
        {
            if (m_output_dims.d[i] != 0)
            {
                m_output_area *= m_output_dims.d[i];
            }
        }

        // 3. malloc: raw network output and its transposed copy (same size).
        // Computed in size_t so the byte count is not formed in int first.
        const size_t output_bytes = static_cast<size_t>(m_param.batch_size) * m_output_area * sizeof(float);
        checkRuntime(cudaMalloc(&m_output_src_device, output_bytes));
        checkRuntime(cudaMalloc(&m_output_src_transpose_device, output_bytes));

        // 4. cal affine matrix: uniform scale that fits the source into the network
        // input, with centring offsets; then invert it to map boxes back.
        float a = float(m_param.dst_h) / m_param.src_h;
        float b = float(m_param.dst_w) / m_param.src_w;
        float scale = a < b ? a : b;
        const float tx = (-scale * m_param.src_w + m_param.dst_w + scale - 1) * 0.5f;
        const float ty = (-scale * m_param.src_h + m_param.dst_h + scale - 1) * 0.5f;
        cv::Mat src2dst = (cv::Mat_<float>(2, 3) << scale, 0.f, tx,
            0.f, scale, ty);
        cv::Mat dst2src = cv::Mat::zeros(2, 3, CV_32FC1);
        cv::invertAffineTransform(src2dst, dst2src);

        m_dst2src.v0 = dst2src.ptr<float>(0)[0];
        m_dst2src.v1 = dst2src.ptr<float>(0)[1];
        m_dst2src.v2 = dst2src.ptr<float>(0)[2];
        m_dst2src.v3 = dst2src.ptr<float>(1)[0];
        m_dst2src.v4 = dst2src.ptr<float>(1)[1];
        m_dst2src.v5 = dst2src.ptr<float>(1)[2];

        this->check();

        return true;
    }

    // Run the on-device preprocessing chain for one batch:
    // resize -> bgr2rgb -> normalize -> hwc2chw.
    // The chain starts from m_input_src_device (presumably filled by the caller
    // beforehand); imgsBatch itself is not read here.
    void YOLOV8::preprocess(const std::vector<cv::Mat>& imgsBatch)
    {
        const auto batch = m_param.batch_size;
        const auto net_w = m_param.dst_w;
        const auto net_h = m_param.dst_h;

        // Affine resize from src size to the network size using m_dst2src.
        // 114 is presumably the border fill value.
        resizeDevice(batch, m_input_src_device, m_param.src_w, m_param.src_h,
            m_input_resize_device, net_w, net_h, 114, m_dst2src);

        // Channel swap BGR -> RGB.
        bgr2rgbDevice(batch, m_input_resize_device, net_w, net_h,
            m_input_rgb_device, net_w, net_h);

        // Scale / mean / std normalization, configured from m_param.
        normDevice(batch, m_input_rgb_device, net_w, net_h,
            m_input_norm_device, net_w, net_h, m_param);

        // Layout change HWC -> CHW (despite its name, m_input_hwc_device receives the CHW result).
        hwc2chwDevice(batch, m_input_norm_device, net_w, net_h,
            m_input_hwc_device, net_w, net_h);
    }


    /**
     * Turn the raw network output of the last inference into per-image boxes in
     * source-image coordinates.
     *
     * Device side: transpose -> decode (top-K) -> NMS. The objects buffer is then
     * copied to the host with one blocking cudaMemcpy.
     *
     * Objects buffer layout, per image (stride = 1 + topK * m_output_objects_width floats):
     *   slot 0 : number of candidates (stored as float)
     *   rows   : {left, top, right, bottom, conf, label, keep_flag, ...}
     *
     * @param imgsBatch  images of the current batch; only its size is used. At most
     *                   m_param.batch_size entries are filled, any extra entries stay empty.
     * @return one vector of boxes per element of imgsBatch.
     */
    std::vector<std::vector<Box>> YOLOV8::postprocess(const std::vector<cv::Mat>& imgsBatch)
    {
        std::vector<std::vector<Box>> objectss(imgsBatch.size());

        // Values per candidate in the raw output: 4 box values + num_class scores.
        const auto num_channels = 4 + m_param.num_class;
        // Floats per image in the objects buffer (count slot + topK rows).
        const auto image_stride = m_param.topK * m_output_objects_width + 1;

        // transpose: raw output is read as num_channels x m_total_objects and written
        // as m_total_objects x num_channels (argument order presumed (width, height)).
        yolo::transposeDevice(m_param, m_output_src_device, m_total_objects, num_channels, m_total_objects * num_channels,
            m_output_src_transpose_device, num_channels, m_total_objects);

        // decode
        yolo::decodeDevice(m_param, m_output_src_transpose_device, num_channels, m_total_objects, m_output_area,
            m_output_objects_device, m_output_objects_width, m_param.topK);

        // nms
        //nmsDeviceV1(m_param, m_output_objects_device, m_output_objects_width, m_param.topK, image_stride);
        nmsDeviceV2(m_param, m_output_objects_device, m_output_objects_width, m_param.topK, image_stride, m_output_idx_device, m_output_conf_device);

        // Copy result from gpu to cpu. The byte count uses the same stride as the
        // loop below, instead of a hard-coded row width of 7.
        checkRuntime(cudaMemcpy(m_output_objects_host, m_output_objects_device,
            static_cast<size_t>(m_param.batch_size) * image_stride * sizeof(float), cudaMemcpyDeviceToHost));

        // The host buffer holds only m_param.batch_size images; never read past it.
        const size_t num_images = std::min(imgsBatch.size(), static_cast<size_t>(m_param.batch_size));
        for (size_t bi = 0; bi < num_images; bi++)
        {
            const float* image_base = m_output_objects_host + bi * image_stride;
            const int num_boxes = std::min(static_cast<int>(image_base[0]), m_param.topK);
            for (int i = 0; i < num_boxes; i++)
            {
                const float* ptr = image_base + 1 + m_output_objects_width * i;
                // keep flag; presumably cleared by NMS for suppressed boxes
                if (static_cast<int>(ptr[6]) == 0)
                {
                    continue;
                }
                // Map network-input pixel coords back to the source image
                // (YOLOv3/5/6/7/8 heads emit pixel coordinates).
                // YOLOv4-style heads emit normalized coordinates; for those, multiply
                // ptr[0]/ptr[2] by m_param.dst_w and ptr[1]/ptr[3] by m_param.dst_h first.
                const float x_lt = m_dst2src.v0 * ptr[0] + m_dst2src.v1 * ptr[1] + m_dst2src.v2; // left & top
                const float y_lt = m_dst2src.v3 * ptr[0] + m_dst2src.v4 * ptr[1] + m_dst2src.v5;
                const float x_rb = m_dst2src.v0 * ptr[2] + m_dst2src.v1 * ptr[3] + m_dst2src.v2; // right & bottom
                const float y_rb = m_dst2src.v3 * ptr[2] + m_dst2src.v4 * ptr[3] + m_dst2src.v5;

                objectss[bi].emplace_back(x_lt, y_lt, x_rb, y_rb, ptr[4], static_cast<int>(ptr[5]));
            }
        }
        return objectss;
    }
}